{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "GPTNeo_example_notebook.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J0i5MRP0SV8D"
      },
      "source": [
        "Welcome to the Colab notebook for [GPTNeo](https://github.com/EleutherAI/GPTNeo) - a fully open-source implementation of GPT-like models for mesh-tensorflow by [EleutherAI](https://eleuther.ai).\n",
        "\n",
        "Our library provides training and inference for GPT models up to GPT3 sizes on both TPUs and GPUs.\n",
        "\n",
        "In this notebook we walk you through TPU training (or finetuning!) and sampling using the freely available Colab TPUs.\n",
        "\n",
        "If you find our repo useful, come join [our discord](https://discord.gg/BK2v3EJ) and say hi! 😬\n",
        "\n",
        "Before we get going - make sure you are running this notebook with a TPU available. Go to Runtime -> Change Runtime Type and select 'TPU' under hardware accelerator.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K-53qkZV6Lv9",
        "cellView": "form"
      },
      "source": [
        "#@title Setup\n",
        "%tensorflow_version 2.x\n",
        "!git clone https://github.com/EleutherAI/GPTNeo\n",
        "%cd GPTNeo\n",
        "!pip3 install -q -r requirements.txt\n",
        "pretrained_model = None\n",
        "dataset = None\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M0R1owh2qvp8"
      },
      "source": [
        "## Set Up Google Cloud"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0PmzM4dy7diP"
      },
      "source": [
        "To train on TPUs we need to store our data in a Google Cloud Storage bucket, as TPUs can't read from local filesystems.\n",
        "\n",
        "You can set up a bucket by signing up for a free trial here: https://console.cloud.google.com/\n",
        "\n",
        "Make a bucket at https://console.cloud.google.com/storage and come back when that's done.\n",
        "\n",
        "Make sure to select 'Uniform' access control when setting up the bucket, or the Colab notebook won't have the required permissions to read from it.\n",
        "\n",
        "The next cell sets up Google authentication and gives the notebook read and write access to your bucket.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "71bQUjPA7qvj"
      },
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "!gcloud init"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cr_c6A2NBK5i",
        "cellView": "form"
      },
      "source": [
        "path_to_cloud_bucket = 'gs://your-cloud-bucket/' #@param {type:\"string\"}"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EZGbzUPD0tad"
      },
      "source": [
        "## Set Up Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R918l14UhrBR"
      },
      "source": [
        "We first need to download and tokenize a dataset. If you just want to sample from a pretrained model, you can skip this step and move on to the `Pretrained Model` section.\n",
        "\n",
        "You can choose from:\n",
        "\n",
        "*   Sampling Only - choose this option if you only wish to sample from our trained models, then move on to the `Pretrained Model` section.\n",
        "\n",
        "*   OpenWebText - an open-source clone of OpenAI's WebText dataset, the original training data of GPT2.\n",
        "\n",
        "*   YoutubeSubtitles - a dataset of subtitles scraped from YouTube videos.\n",
        "\n",
        "*   HackerNews - comments scraped from Hacker News.\n",
        "\n",
        "*   NIHExporter - data relating to various projects from the National Institutes of Health.\n",
        "\n",
        "*   Custom - if this option is chosen you will be prompted to enter the path to your own dataset. It should be a directory containing .txt or .jsonl files.\n",
        "\n",
        "All these datasets are from EleutherAI's side project - [The Pile™](https://github.com/EleutherAI/The-Pile) - an effort to gather a general purpose, diverse and open source plain text dataset large enough to train 1T+ parameter language models.\n",
        "\n",
        "Even the smallest datasets are fairly large files, so this step will likely take a while. Select a dataset in the next cell, then run the next two cells, and go grab a snack and a cup of tea 😊\n",
        "\n",
        "Alternatively, you can provide your own dataset in the form of a folder or gzip archive of .txt files. Simply select 'Custom' below and enter the path to your data and the name of your dataset when prompted."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pM8jP3Am_hsx",
        "cellView": "form"
      },
      "source": [
        "# Select a Dataset:\n",
        "import os\n",
        "dataset = 'Sampling_Only' #@param [\"Sampling_Only\", \"OpenWebText\", \"YoutubeSubtitles\", \"HackerNews\", \"NIHExporter\", \"Custom\"]\n",
        "\n",
        "if dataset == \"Sampling_Only\":\n",
        "  pass\n",
        "elif dataset == 'OpenWebText':\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/openwebtext2.jsonl.zst.tar -O openwebtext.tar.xz\n",
        "  !tar xf openwebtext.tar.xz\n",
        "  dataset_path = \"openwebtext\"\n",
        "  dataset_name = dataset_path\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == 'YoutubeSubtitles':\n",
        "  os.makedirs('data', exist_ok=True)\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/yt_subs.jsonl.zst -O data/yt_subs.jsonl.zst\n",
        "  dataset_path = 'data'\n",
        "  dataset_name = 'ytsubs'\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == 'HackerNews':\n",
        "  os.makedirs('data', exist_ok=True)\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/hn.tar.gz -O data/hn.tar.gz\n",
        "  dataset_path = 'data'\n",
        "  dataset_name = 'hackernews'\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == \"NIHExporter\":\n",
        "  os.makedirs('data', exist_ok=True)\n",
        "  # wget writes straight into data/, so no separate move step is needed\n",
        "  !wget https://the-eye.eu/public/AI/pile_preliminary_components/NIH_ExPORTER_awarded_grant_text.jsonl.zst -O data/NIH_ExPORTER_awarded_grant_text.jsonl.zst\n",
        "  dataset_path = 'data'\n",
        "  dataset_name = 'nihexporter'\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "elif dataset == \"Custom\":\n",
        "  dataset_path = input('Enter the path to the folder containing your data: ')\n",
        "  dataset_name = input('Enter the name of your dataset: ')\n",
        "  out_name = dataset_name + \"_tokenized\"\n",
        "else:\n",
        "  raise NotImplementedError('please select from available options: [\"Sampling_Only\", \"OpenWebText\", \"YoutubeSubtitles\", \"HackerNews\", \"NIHExporter\", \"Custom\"]')\n"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zMl1cHtN5I_W"
      },
      "source": [
        "### Tokenize and Upload Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6IBIompTJaqm"
      },
      "source": [
        "Now tokenize the dataset and copy it over to your google cloud bucket. You may skip this step if you are sampling from a pre-trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pq5u0WUSJWwz",
        "cellView": "both"
      },
      "source": [
        "# Tokenize Data\n",
        "!python data/create_tfrecords.py --input_dir /content/GPTNeo/$dataset_path --name $dataset_name --files_per 1000 --output_dir $out_name --write_dataset_config --processes 1\n",
        "\n",
        "# copy the data to your bucket\n",
        "if not path_to_cloud_bucket.endswith('/'):\n",
        "  path_to_cloud_bucket += '/'\n",
        "copy_loc = path_to_cloud_bucket + \"datasets/\" + dataset\n",
        "!gsutil -m cp -r /content/GPTNeo/$out_name $copy_loc\n",
        "!gsutil ls $path_to_cloud_bucket"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NhvmTFD7b_fb"
      },
      "source": [
        "Before starting training, you'll need to edit your dataset and model configs to point to your bucket and data. You need to do this even if you are sampling from a pre-trained model.\n",
        "\n",
        "*   First, change the `%%writefile` path to point to your chosen dataset, i.e. replace `dataset_name` in `%%writefile configs/dataset_configs/dataset_name.json` with the name of your dataset - e.g. `%%writefile configs/dataset_configs/ytsubs.json`\n",
        "*   Change the \"path\" field to point to your cloud bucket location - e.g. `gs://neo_lmdatasets/datasets/ytsubs_*.tfrecords`\n",
        "*   Once you've made the edits, run the cell below to overwrite the existing files.\n",
        "\n",
        "\n"
      ]
    },
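    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of the edits described above, the same config can be built in Python instead of edited by hand. The helper name `make_dataset_config` and the bucket `gs://your-cloud-bucket/` are placeholders rather than part of GPTNeo, and the `datasets/<name>/<name>*.tfrecords` layout is an assumption modelled on the default config in the next cell, so check it against what `gsutil ls` shows in your bucket.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "def make_dataset_config(bucket, name):\n",
        "    # Build a dataset config for tfrecords stored under the given bucket.\n",
        "    bucket = bucket.rstrip('/')\n",
        "    return {\n",
        "        # Assumed layout: <bucket>/datasets/<name>/<name>*.tfrecords\n",
        "        'path': f'{bucket}/datasets/{name}/{name}*.tfrecords',\n",
        "        'eval_path': '',\n",
        "        'n_vocab': 50257,  # same value as n_vocab in the model config\n",
        "        'tokenizer_is_pretrained': True,\n",
        "        'tokenizer_path': 'gpt2',\n",
        "        'eos_id': 50256,\n",
        "        'padding_id': 50257,\n",
        "    }\n",
        "\n",
        "config = make_dataset_config('gs://your-cloud-bucket/', 'ytsubs')\n",
        "print(json.dumps(config, indent=2))\n",
        "```\n",
        "\n",
        "The printed JSON has the same fields as the config cell that follows; only the `path` value changes with your bucket and dataset name."
      ]
    },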
    {
      "cell_type": "code",
      "metadata": {
        "id": "MCsZP48vavCP"
      },
      "source": [
        "%%writefile configs/dataset_configs/Sampling_Only.json\n",
        "\n",
        "{\n",
        "  \"path\": \"gs://eleutherai/datasets/Sampling_Only/Sampling_Only*.tfrecords\",\n",
        "  \"eval_path\": \"\",\n",
        "  \"n_vocab\": 50257,\n",
        "  \"tokenizer_is_pretrained\": true,\n",
        "  \"tokenizer_path\": \"gpt2\",\n",
        "  \"eos_id\": 50256,\n",
        "  \"padding_id\": 50257\n",
        "}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dH0x3dI9j85P"
      },
      "source": [
        "## Set Model Configs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I6GnCgAkB7GQ"
      },
      "source": [
        "The model below is identical to our pretrained GPT3_XL model (1.3B parameters).\n",
        "\n",
        "If you want to use a smaller model, you can modify any of the config files in `./configs/` ending in `_8.json`, all of which are designed to train on 8-core TPUs.\n",
        "\n",
        "For a more detailed breakdown of what each item in the configuration file means, please read through the training and config guides in our [github README](https://github.com/EleutherAI/GPTNeo#training-guide).\n",
        "\n",
        "You'll want to change the first item in the `datasets` list to the name of your chosen dataset (the filename minus `.json` in `./configs/dataset_configs`).\n",
        "\n",
        "You'll also want to modify the `model_path` field to point to your Google Cloud bucket, so checkpoints get saved there."
      ]
    },
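    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The 1.3B figure can be roughly reproduced from the `n_layer`, `n_embd`, `n_vocab` and `n_ctx` values in the config below. This is a back-of-the-envelope sketch: it assumes a standard GPT-style block (about 12 * n_embd^2 weights per layer) with learned position embeddings, and it ignores biases and layer norms. The function name is illustrative only.\n",
        "\n",
        "```python\n",
        "def approx_param_count(n_layer, n_embd, n_vocab, n_ctx):\n",
        "    # Attention (4 * n_embd^2) plus feed-forward (8 * n_embd^2) per layer.\n",
        "    per_layer = 12 * n_embd ** 2\n",
        "    # Token embeddings plus position embeddings.\n",
        "    embeddings = (n_vocab + n_ctx) * n_embd\n",
        "    return n_layer * per_layer + embeddings\n",
        "\n",
        "total = approx_param_count(n_layer=24, n_embd=2048, n_vocab=50257, n_ctx=2048)\n",
        "print(f'{total / 1e9:.2f}B parameters')\n",
        "```\n",
        "\n",
        "The result lands close to 1.3B, which is a quick sanity check to repeat if you shrink `n_layer` or `n_embd` for a smaller model."
      ]
    },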
    {
      "cell_type": "code",
      "metadata": {
        "id": "L9hUDdokiWj6"
      },
      "source": [
        "%%writefile configs/GPT3_XL.json\n",
        "\n",
        "{\n",
        "    \"n_head\": 16,\n",
        "    \"n_vocab\": 50257,\n",
        "    \"embed_dropout\": 0,\n",
        "    \"lr\": 0.0002,\n",
        "    \"lr_decay\": \"cosine\",\n",
        "    \"warmup_steps\": 3000,\n",
        "    \"beta1\": 0.9,\n",
        "    \"beta2\": 0.95,\n",
        "    \"epsilon\": 1e-8,\n",
        "    \"opt_name\": \"adam\",\n",
        "    \"weight_decay\": 0,\n",
        "    \"train_batch_size\": 256,\n",
        "    \"attn_dropout\": 0,\n",
        "    \"train_steps\": 600000,\n",
        "    \"eval_steps\": 0,\n",
        "    \"predict_steps\": 1,\n",
        "    \"res_dropout\": 0,\n",
        "    \"eval_batch_size\": 4,\n",
        "    \"predict_batch_size\": 1,\n",
        "    \"iterations\": 100,\n",
        "    \"n_embd\": 2048,\n",
        "    \"datasets\": [[\"pile\", null, null, null]],\n",
        "    \"model\": \"GPT\",\n",
        "    \"model_path\": \"gs://eleutherai/GPT3_XL\",\n",
        "    \"n_ctx\": 2048,\n",
        "    \"n_layer\": 24,\n",
        "    \"scale_by_depth\": true,\n",
        "    \"scale_by_in\": false,\n",
        "    \"attention_types\" :  [[[\"global\", \"local\"],12]],\n",
        "    \"mesh_shape\": \"x:4,y:2\",\n",
        "    \"layout\": \"intermediate_expanded:x,heads:x,vocab:n_vocab,memory_length:y,embd:y\",\n",
        "    \"activation_function\": \"gelu\",\n",
        "    \"recompute_grad\": true,\n",
        "    \"gradient_clipping\": 1.0,\n",
        "    \"tokens_per_mb_per_replica\": 2048,\n",
        "    \"precision\": \"bfloat16\"\n",
        "}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GWK9MJqwcXKn"
      },
      "source": [
        "## Training from Scratch\n",
        "\n",
        "Now we will begin to train the model. If no previous model is found in \"model_path\", the model will start training from scratch. If you'd prefer to finetune from a pretrained model, skip to the `Pretrained Model` section.\n",
        "\n",
        "If everything's set up correctly, you can now run `main.py` to start training!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VUtrysOSBzjJ"
      },
      "source": [
        "!python3 main.py --model GPT3_XL --steps_per_checkpoint 500 --tpu colab"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "koKQHA5ikCvD"
      },
      "source": [
        "## Pretrained Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0QZv4_pnkk26"
      },
      "source": [
        "If you want to sample from or finetune a pretrained model, EleutherAI has released two: one with [1.3B parameters](https://the-eye.eu/public/AI/gptneo-release/GPT3_XL/), and another with [2.7B](https://the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/).\n",
        "\n",
        "Select an option below to download the weights locally. You will then need to upload them to your cloud bucket in order to finetune from them. If the download command isn't working, try the commented-out code to download from a different source.\n",
        "\n",
        "The 2.7B model likely won't fit into the Colab TPU's memory, and you may need larger TPU pods to finetune from it.\n",
        "\n",
        "Sampling from it, however, works just fine.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lgTG1ammqGB0",
        "cellView": "form"
      },
      "source": [
        "# @title Download pretrained model weights:\n",
        "pretrained_model = 'GPT3_2-7B' #@param [\"GPT3_XL\", \"GPT3_2-7B\"]\n",
        "!wget -m -np -c -U \"eye02\" -w 2 -R \"index.html*\" \"https://the-eye.eu/public/AI/gptneo-release/$pretrained_model/\"\n",
        "path_to_local_weights = f\"/content/GPTNeo/the-eye.eu/public/AI/gptneo-release/{pretrained_model}\"\n",
        "\n",
        "# URL = f\"http://eaidata.bmk.sh/data/gptneo-release/{pretrained_model}/\"\n",
        "# FOLDER_NAME = \"GPT3_XL\"\n",
        "# !curl $URL | grep -i \"</a>\" | sed -n 's/.*href=\"\\([^\"]*\\).*/\\1/p' | sed \"s|^|$URL|\" | xargs -n 1 -P 4 wget -P $pretrained_model\n",
        "# path_to_local_weights = pretrained_model\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GU3BDNJN_ZXE"
      },
      "source": [
        "# upload to your bucket\n",
        "bucket_base = \"gs://\" + path_to_cloud_bucket.replace('gs://', '').split('/')[0]\n",
        "!gsutil -m cp -r $path_to_local_weights $bucket_base"
      ],
      "execution_count": 9,
      "outputs": []
    },
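    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `bucket_base` expression in the cell above keeps only the bucket name, so the weights are copied to the top level of the bucket even when `path_to_cloud_bucket` points at a subfolder. A small illustration of that string handling, with `bucket_root` as a hypothetical helper name:\n",
        "\n",
        "```python\n",
        "def bucket_root(path):\n",
        "    # 'gs://my-bucket/some/dir/' -> 'gs://my-bucket'\n",
        "    name = path.replace('gs://', '').split('/')[0]\n",
        "    return 'gs://' + name\n",
        "\n",
        "print(bucket_root('gs://your-cloud-bucket/datasets/'))\n",
        "```"
      ]
    },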
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bnqkKBTOn0ox"
      },
      "source": [
        "If everything has worked successfully you should now see your model listed in your bucket below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "80t9MMionm2h"
      },
      "source": [
        "!gsutil ls $bucket_base"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QDKL8fCSoApL"
      },
      "source": [
        "Now we want to make a few modifications to the model config in order to get training / sampling working on Colab.\n",
        "\n",
        "If you are just sampling from our pretrained models, you can leave the settings as is, run the cell below, then move on to the `Sample from your model` section.\n",
        "\n",
        "If finetuning, you can change the parameters below.\n",
        "\n",
        "* `path_to_model` should point to the model weights location in your cloud bucket, and will default to `$bucket_base/${pretrained_model}` if nothing is entered.\n",
        "\n",
        "* `batch_size` is your train batch size - if you're encountering memory errors, try lowering this.\n",
        "\n",
        "* `dset` is the name of your dataset. If nothing is entered, this defaults to the dataset you selected in the `Set Up Dataset` section.\n",
        "\n",
        "* `mesh_shape` specifies the way the model will be divided up across the TPU cores. We suggest leaving this alone unless you know what you're doing.\n",
        "\n",
        "* `train_steps` specifies how many steps you want the model to finetune for. We set this to 1000 for demonstration purposes, but you may need to increase it depending on your goals. If you are just sampling from the model, you can leave this as is.\n",
        "\n",
        "* `steps_per_checkpoint` specifies how often you want to save model weights during training.\n",
        "\n"
      ]
    },
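    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make two of the settings above concrete: a `mesh_shape` string such as `x:4,y:2` names mesh dimensions and their sizes, and we assume here that the sizes should multiply to the number of TPU cores (8 on a Colab TPU). The `train_steps` value written to the config in the next cell is the pretrained checkpoint's step plus the number of steps you ask for. The helper names below are hypothetical and not part of GPTNeo.\n",
        "\n",
        "```python\n",
        "def parse_mesh_shape(mesh_shape):\n",
        "    # 'x:4,y:2' -> {'x': 4, 'y': 2}\n",
        "    dims = {}\n",
        "    for part in mesh_shape.split(','):\n",
        "        name, size = part.split(':')\n",
        "        dims[name.strip()] = int(size)\n",
        "    return dims\n",
        "\n",
        "def mesh_size(mesh_shape):\n",
        "    # Multiply the dimension sizes together.\n",
        "    total = 1\n",
        "    for size in parse_mesh_shape(mesh_shape).values():\n",
        "        total *= size\n",
        "    return total\n",
        "\n",
        "start_step, train_steps = 400000, 1000\n",
        "print(parse_mesh_shape('x:4,y:2'), mesh_size('x:4,y:2'))\n",
        "print('final train step:', start_step + train_steps)\n",
        "```"
      ]
    },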
    {
      "cell_type": "code",
      "metadata": {
        "id": "Laf0slBMDCUj",
        "cellView": "form"
      },
      "source": [
        "# @title Modify config for colab.\n",
        "\n",
        "import json\n",
        "from pprint import pprint\n",
        "\n",
        "path_to_model = \"\" #@param {type:\"string\"}\n",
        "batch_size = 8 #@param {type:\"integer\"}\n",
        "dset = \"\"  #@param {type:\"string\"}\n",
        "mesh_shape = \"x:4,y:2\" #@param {type:\"string\"}\n",
        "train_steps = 1000 #@param {type:\"integer\"}\n",
        "steps_per_checkpoint = 500 #@param {type:\"integer\"}\n",
        "start_step = 400000 if pretrained_model == \"GPT3_2-7B\" else 362000\n",
        "\n",
        "if path_to_model == \"\":\n",
        "  path_to_model = f'{bucket_base.strip(\"/\")}/{pretrained_model}'\n",
        "print(f'MODEL PATH: {path_to_model}\\n')\n",
        "\n",
        "if dset == \"\":\n",
        "  if dataset is None:\n",
        "    # no dataset was selected earlier, so fall back to \"pile\"\n",
        "    dset = \"pile\"\n",
        "  elif dataset != \"Sampling_Only\":\n",
        "    dset = dataset\n",
        "\n",
        "def pad_to_multiple_of(n, mult):\n",
        "  \"\"\"\n",
        "  pads n to a multiple of mult\n",
        "  \"\"\"\n",
        "  extra = n % mult\n",
        "  if extra > 0:\n",
        "      n = n + mult - extra\n",
        "  return n\n",
        "\n",
        "with open(f'{path_to_local_weights}/config.json', 'r') as f:\n",
        "  data = json.load(f)\n",
        "  pprint(data)\n",
        "  dset_val = [[dset, None, None, None]] if dset != \"\" else data[\"datasets\"]\n",
        "  mods = {\n",
        "          \"mesh_shape\": mesh_shape,\n",
        "          \"layout\": \"intermediate_expanded:x,heads:x,memory_length:y,embd:y\",\n",
        "          \"model_path\": path_to_model,\n",
        "          \"datasets\": dset_val,\n",
        "          \"train_steps\": start_step + train_steps,\n",
        "          \"eval_steps\": 0,\n",
        "          \"train_batch_size\": batch_size,\n",
        "          \"predict_batch_size\": batch_size\n",
        "        }\n",
        "  data.update(mods)\n",
        "  print('\\n--->\\n')\n",
        "  pprint(data)\n",
        "  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\n",
        "    json.dump(data, outfile, indent=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fPwwbPCA6O7r"
      },
      "source": [
        "### Begin Fine-Tuning\n",
        "\n",
        "If you are fine-tuning the pretrained model, this line of code will begin the training."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0YlaHzyXuMaj"
      },
      "source": [
        "!python3 main.py --model $pretrained_model --steps_per_checkpoint $steps_per_checkpoint --tpu colab"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I_HxtEmBGTGT"
      },
      "source": [
        "### Sample from your model\n",
        "\n",
        "Once training is finished (or your pretrained model is on your bucket), you can run the same command with the `--predict` flag to sample from your model.\n",
        "\n",
        "To pass in a prompt, save it to a .txt file, and pass in the name of the file with the `--prompt` flag.\n",
        "\n",
        "Use the cell below to enter your prompt, and run it to save it to `example_prompt.txt`.\n",
        "\n",
        "You may need to decrease the predict batch size in your config if you're facing OOM errors.\n",
        "\n",
        "Let's see if the GPTNeo model can finish coding itself, with a sample prompt consisting of the beginning of a `torch.nn.Module`:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CQE1Y5wPFx7h",
        "outputId": "e1a92c0c-18ee-4014-a0b8-d67161384940",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "%%writefile example_prompt.txt\n",
        "\n",
        "class GPT(nn.Module):\n",
        "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
        "\n",
        "    def __init__(self, config):\n",
        "        super().__init__()\n",
        "\n",
        "        # input embedding stem\n",
        "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
        "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
        "        self.drop = nn.Dropout(config.embd_pdrop)\n",
        "        # transformer\n",
        "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
        "        # decoder head\n",
        "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
        "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
        "\n",
        "        self.block_size = config.block_size\n",
        "        self.apply(self._init_weights)\n",
        "\n",
        "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Overwriting example_prompt.txt\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sf_5E4fHFQIh",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f3c12a94-7ef8-43c1-a668-6365966d42b4"
      },
      "source": [
        "!python3 main.py --model $pretrained_model --steps_per_checkpoint 500 --tpu colab --predict --prompt example_prompt.txt"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "2021-03-22 12:20:43.411018: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/compat/v2_compat.py:96: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "non-resource variables are not supported in the long term\n",
            "Current step 400000\n",
            "Saving config to gs://test-bucket-neo/GPT3_2-7B\n",
            "2021-03-22 12:20:50.689547: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
            "2021-03-22 12:20:50.691059: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1\n",
            "2021-03-22 12:20:50.701975: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
            "2021-03-22 12:20:50.702051: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (eeb4af61eb99): /proc/driver/nvidia/version does not exist\n",
            "2021-03-22 12:20:52.229703: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:196] None of the MLIR optimization passes are enabled (registered 0 passes)\n",
            "Done!\n",
            "params = defaultdict(<function fetch_model_params.<locals>.<lambda> at 0x7f64ee76fb90>, {'n_head': 20, 'n_vocab': 50257, 'embed_dropout': 0, 'lr': 0.00016, 'lr_decay': 'cosine', 'warmup_steps': 3000, 'beta1': 0.9, 'beta2': 0.95, 'epsilon': 1e-08, 'ada_epsilon1': '1e-30', 'ada_epsilon2': 0.001, 'opt_name': 'adam', 'weight_decay': 0, 'train_batch_size': 16, 'attn_dropout': 0, 'train_steps': 401000, 'lr_decay_end': 300000, 'eval_steps': 0, 'predict_steps': 0, 'res_dropout': 0, 'eval_batch_size': 128, 'predict_batch_size': 4, 'iterations': 500, 'n_embd': 2560, 'datasets': [['pile', None, None, None]], 'model_path': 'gs://test-bucket-neo/GPT3_2-7B', 'n_ctx': 2048, 'n_layer': 32, 'scale_by_depth': True, 'scale_by_in': False, 'attention_types': ['global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local', 'global', 'local'], 'mesh_shape': 'x:4,y:2', 'layout': 'intermediate_expanded:x,heads:x,memory_length:y,embd:y', 'activation_function': 'gelu', 'recompute_grad': True, 'gradient_clipping': 1.0, 'tokens_per_mb_per_replica': 4096, 'padding_id': 50257, 'eos_id': 50256, 'dataset_configs': {'pile': {'n_vocab': 50257, 'path': 'gs://neo-datasets/pile/pile_*.tfrecords', 'eval_path': 'gs://neo-datasets/pile_val.tfrecords', 'tokenizer_is_pretrained': True, 'tokenizer_path': 'gpt2', 'eos_id': 50256, 'padding_id': 50257}}, 'mlm_training': False, 'causal': True, 'num_cores': 8, 'auto_layout': False, 'auto_layout_and_mesh_shape': False, 'use_tpu': True, 'gpu_ids': ['device:GPU:0'], 'steps_per_checkpoint': 500, 'predict': True, 'model': 'GPT', 'export': False, 'sampling_use_entmax': False, 'moe_layers': None, 'slow_sampling': False})\n",
            "Using config: {'_model_dir': 'gs://test-bucket-neo/GPT3_2-7B', '_tf_random_seed': None, '_save_summary_steps': 500, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
            "cluster_def {\n",
            "  job {\n",
            "    name: \"worker\"\n",
            "    tasks {\n",
            "      key: 0\n",
            "      value: \"10.82.219.162:8470\"\n",
            "    }\n",
            "  }\n",
            "}\n",
            "isolate_session_state: true\n",
            ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.82.219.162:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.82.219.162:8470', '_evaluation_master': 'grpc://10.82.219.162:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=500, num_shards=8, num_cores_per_replica=1, per_host_input_for_training=4, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver object at 0x7f64ee774a90>}\n",
            "_TPUContext: eval_on_tpu True\n",
            "Predictions generated\n",
            "Querying Tensorflow master (grpc://10.82.219.162:8470) for TPU system metadata.\n",
            "2021-03-22 12:20:53.623443: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:373] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.\n",
            "Initializing TPU system (master: grpc://10.82.219.162:8470) to fetch topology for model parallelism. This might take a while.\n",
            "Found TPU system:\n",
            "*** Num TPU Cores: 8\n",
            "*** Num TPU Workers: 1\n",
            "*** Num TPU Cores Per Worker: 8\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 6478766768852144079)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1341089584581626564)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -607673649088781696)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -4050793109911027603)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, -6683233089843062258)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -4741539030516422912)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 2164395643386766058)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 3352841220362516620)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 5726423099271110669)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 7316344872981758207)\n",
            "*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 7432402242254058183)\n",
            "Calling model_fn.\n",
            "num_cores_per_replica: 1\n",
            "computation_shape: [1, 1, 1, 1]\n",
            "num_replicas: 8\n",
            "device_assignment.topology.device_coordinates: [[[0 0 0 0]\n",
            "  [0 0 0 1]\n",
            "  [1 0 0 0]\n",
            "  [1 0 0 1]\n",
            "  [0 1 0 0]\n",
            "  [0 1 0 1]\n",
            "  [1 1 0 0]\n",
            "  [1 1 0 1]]]\n",
            "device_assignment.core_assignment: [[[0 0 0 0]]\n",
            "\n",
            " [[0 0 0 1]]\n",
            "\n",
            " [[1 0 0 0]]\n",
            "\n",
            " [[1 0 0 1]]\n",
            "\n",
            " [[0 1 0 0]]\n",
            "\n",
            " [[0 1 0 1]]\n",
            "\n",
            " [[1 1 0 0]]\n",
            "\n",
            " [[1 1 0 1]]]\n",
            "2021-03-22 12:21:11.005988: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
            "device_list = ['/job:worker/task:0/device:CPU:0']\n",
            "SimdMeshImpl ignoring devices ['', '', '', '', '', '', '', '']\n",
            "SimdMeshImpl init: Shape[x=4, y=2] LayoutRules{('heads', 'x'), ('embd', 'y'), ('intermediate_expanded', 'x'), ('memory_length', 'y')}\n",
            "Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f64e9078050>\n",
            "Create pnum_tensor\n",
            "Variable gpt2/h0/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h0/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h0/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h0/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h0/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h0/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h1/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h1/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h1/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h1/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h1/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h1/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h10/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h10/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h10/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h10/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h10/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h10/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h11/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h11/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h11/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h11/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h11/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h11/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h12/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h12/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h12/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h12/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h12/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h12/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h13/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h13/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h13/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h13/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h13/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h13/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h14/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h14/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h14/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h14/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h14/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h14/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h15/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h15/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h15/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h15/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h15/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h15/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h16/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h16/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h16/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h16/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h16/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h16/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h17/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h17/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h17/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h17/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h17/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h17/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h18/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h18/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h18/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h18/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h18/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h18/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h19/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h19/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h19/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h19/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h19/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h19/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h2/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h2/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h2/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h2/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h2/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h2/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h20/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h20/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h20/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h20/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h20/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h20/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h21/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h21/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h21/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h21/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h21/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h21/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h22/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h22/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h22/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h22/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h22/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h22/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h23/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h23/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h23/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h23/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h23/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h23/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h24/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h24/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h24/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h24/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h24/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h24/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h25/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h25/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h25/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h25/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h25/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h25/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h26/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h26/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h26/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h26/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h26/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h26/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h27/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h27/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h27/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h27/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h27/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h27/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h28/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h28/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h28/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h28/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h28/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h28/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h29/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h29/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h29/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h29/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h29/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h29/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h3/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h3/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h3/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h3/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h3/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h3/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h30/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h30/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h30/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h30/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h30/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h30/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h31/attn/k                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h31/attn/o                                              size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h31/attn/q                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h31/attn/v                                              size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h31/mlp/conv1d_main/c_fc/kernel                         size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h31/mlp/conv1d_main/c_proj/kernel                       size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h4/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h4/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h4/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h4/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h4/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h4/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h5/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h5/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h5/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h5/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h5/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h5/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h6/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h6/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h6/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h6/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h6/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h6/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h7/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h7/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h7/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h7/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h7/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h7/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h8/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h8/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h8/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h8/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h8/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h8/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/h9/attn/k                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h9/attn/o                                               size 6553600      slice_size 819200       Shape[heads=2560, embd=2560]                                \n",
            "Variable gpt2/h9/attn/q                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h9/attn/v                                               size 6553600      slice_size 819200       Shape[embd=2560, heads=2560]                                \n",
            "Variable gpt2/h9/mlp/conv1d_main/c_fc/kernel                          size 26214400     slice_size 3276800      Shape[embd=2560, intermediate_expanded=10240]               \n",
            "Variable gpt2/h9/mlp/conv1d_main/c_proj/kernel                        size 26214400     slice_size 3276800      Shape[intermediate_expanded=10240, embd=2560]               \n",
            "Variable gpt2/wpe                                                     size 5242880      slice_size 2621440      Shape[embed_sequence=2048, embd=2560]                       \n",
            "Variable gpt2/wte                                                     size 128657920    slice_size 64328960     Shape[vocab=50257, embd=2560]                               \n",
            "Variable stacked/gpt2/h0/mlp/conv1d_main/c_fc/bias                    size 256000       slice_size 64000        Shape[stacked=25, intermediate_expanded=10240]              \n",
            "    gpt2/h0/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h1/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h2/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h3/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h4/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h5/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h6/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h7/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h8/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h9/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h10/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h11/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h12/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h13/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h14/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h15/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h16/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h17/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h18/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h19/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h20/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h21/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h22/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h23/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h24/mlp/conv1d_main/c_fc/bias\n",
            "Variable stacked/gpt2/h0/norm_1/g                                     size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \n",
            "    gpt2/h0/norm_1/g\n",
            "    gpt2/h0/norm_1/b\n",
            "    gpt2/h0/attn/compute_output_bias/o_b\n",
            "    gpt2/h0/norm_2/g\n",
            "    gpt2/h0/norm_2/b\n",
            "    gpt2/h0/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h1/norm_1/g\n",
            "    gpt2/h1/norm_1/b\n",
            "    gpt2/h1/attn/compute_output_bias/o_b\n",
            "    gpt2/h1/norm_2/g\n",
            "    gpt2/h1/norm_2/b\n",
            "    gpt2/h1/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h2/norm_1/g\n",
            "    gpt2/h2/norm_1/b\n",
            "    gpt2/h2/attn/compute_output_bias/o_b\n",
            "    gpt2/h2/norm_2/g\n",
            "    gpt2/h2/norm_2/b\n",
            "    gpt2/h2/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h3/norm_1/g\n",
            "    gpt2/h3/norm_1/b\n",
            "    gpt2/h3/attn/compute_output_bias/o_b\n",
            "    gpt2/h3/norm_2/g\n",
            "    gpt2/h3/norm_2/b\n",
            "    gpt2/h3/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h4/norm_1/g\n",
            "    gpt2/h4/norm_1/b\n",
            "    gpt2/h4/attn/compute_output_bias/o_b\n",
            "    gpt2/h4/norm_2/g\n",
            "    gpt2/h4/norm_2/b\n",
            "    gpt2/h4/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h5/norm_1/g\n",
            "    gpt2/h5/norm_1/b\n",
            "    gpt2/h5/attn/compute_output_bias/o_b\n",
            "    gpt2/h5/norm_2/g\n",
            "    gpt2/h5/norm_2/b\n",
            "    gpt2/h5/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h6/norm_1/g\n",
            "    gpt2/h6/norm_1/b\n",
            "    gpt2/h6/attn/compute_output_bias/o_b\n",
            "    gpt2/h6/norm_2/g\n",
            "    gpt2/h6/norm_2/b\n",
            "    gpt2/h6/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h7/norm_1/g\n",
            "    gpt2/h7/norm_1/b\n",
            "    gpt2/h7/attn/compute_output_bias/o_b\n",
            "    gpt2/h7/norm_2/g\n",
            "    gpt2/h7/norm_2/b\n",
            "    gpt2/h7/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h8/norm_1/g\n",
            "    gpt2/h8/norm_1/b\n",
            "    gpt2/h8/attn/compute_output_bias/o_b\n",
            "Variable stacked/gpt2/h17/norm_1/g                                    size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \n",
            "    gpt2/h17/norm_1/g\n",
            "    gpt2/h17/norm_1/b\n",
            "    gpt2/h17/attn/compute_output_bias/o_b\n",
            "    gpt2/h17/norm_2/g\n",
            "    gpt2/h17/norm_2/b\n",
            "    gpt2/h17/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h18/norm_1/g\n",
            "    gpt2/h18/norm_1/b\n",
            "    gpt2/h18/attn/compute_output_bias/o_b\n",
            "    gpt2/h18/norm_2/g\n",
            "    gpt2/h18/norm_2/b\n",
            "    gpt2/h18/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h19/norm_1/g\n",
            "    gpt2/h19/norm_1/b\n",
            "    gpt2/h19/attn/compute_output_bias/o_b\n",
            "    gpt2/h19/norm_2/g\n",
            "    gpt2/h19/norm_2/b\n",
            "    gpt2/h19/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h20/norm_1/g\n",
            "    gpt2/h20/norm_1/b\n",
            "    gpt2/h20/attn/compute_output_bias/o_b\n",
            "    gpt2/h20/norm_2/g\n",
            "    gpt2/h20/norm_2/b\n",
            "    gpt2/h20/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h21/norm_1/g\n",
            "    gpt2/h21/norm_1/b\n",
            "    gpt2/h21/attn/compute_output_bias/o_b\n",
            "    gpt2/h21/norm_2/g\n",
            "    gpt2/h21/norm_2/b\n",
            "    gpt2/h21/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h22/norm_1/g\n",
            "    gpt2/h22/norm_1/b\n",
            "    gpt2/h22/attn/compute_output_bias/o_b\n",
            "    gpt2/h22/norm_2/g\n",
            "    gpt2/h22/norm_2/b\n",
            "    gpt2/h22/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h23/norm_1/g\n",
            "    gpt2/h23/norm_1/b\n",
            "    gpt2/h23/attn/compute_output_bias/o_b\n",
            "    gpt2/h23/norm_2/g\n",
            "    gpt2/h23/norm_2/b\n",
            "    gpt2/h23/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h24/norm_1/g\n",
            "    gpt2/h24/norm_1/b\n",
            "    gpt2/h24/attn/compute_output_bias/o_b\n",
            "    gpt2/h24/norm_2/g\n",
            "    gpt2/h24/norm_2/b\n",
            "    gpt2/h24/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h25/norm_1/g\n",
            "    gpt2/h25/norm_1/b\n",
            "    gpt2/h25/attn/compute_output_bias/o_b\n",
            "Variable stacked/gpt2/h25/mlp/conv1d_main/c_fc/bias                   size 71680        slice_size 17920        Shape[stacked=7, intermediate_expanded=10240]               \n",
            "    gpt2/h25/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h26/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h27/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h28/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h29/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h30/mlp/conv1d_main/c_fc/bias\n",
            "    gpt2/h31/mlp/conv1d_main/c_fc/bias\n",
            "Variable stacked/gpt2/h25/norm_2/g                                    size 104960       slice_size 52480        Shape[stacked=41, embd=2560]                                \n",
            "    gpt2/h25/norm_2/g\n",
            "    gpt2/h25/norm_2/b\n",
            "    gpt2/h25/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h26/norm_1/g\n",
            "    gpt2/h26/norm_1/b\n",
            "    gpt2/h26/attn/compute_output_bias/o_b\n",
            "    gpt2/h26/norm_2/g\n",
            "    gpt2/h26/norm_2/b\n",
            "    gpt2/h26/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h27/norm_1/g\n",
            "    gpt2/h27/norm_1/b\n",
            "    gpt2/h27/attn/compute_output_bias/o_b\n",
            "    gpt2/h27/norm_2/g\n",
            "    gpt2/h27/norm_2/b\n",
            "    gpt2/h27/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h28/norm_1/g\n",
            "    gpt2/h28/norm_1/b\n",
            "    gpt2/h28/attn/compute_output_bias/o_b\n",
            "    gpt2/h28/norm_2/g\n",
            "    gpt2/h28/norm_2/b\n",
            "    gpt2/h28/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h29/norm_1/g\n",
            "    gpt2/h29/norm_1/b\n",
            "    gpt2/h29/attn/compute_output_bias/o_b\n",
            "    gpt2/h29/norm_2/g\n",
            "    gpt2/h29/norm_2/b\n",
            "    gpt2/h29/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h30/norm_1/g\n",
            "    gpt2/h30/norm_1/b\n",
            "    gpt2/h30/attn/compute_output_bias/o_b\n",
            "    gpt2/h30/norm_2/g\n",
            "    gpt2/h30/norm_2/b\n",
            "    gpt2/h30/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h31/norm_1/g\n",
            "    gpt2/h31/norm_1/b\n",
            "    gpt2/h31/attn/compute_output_bias/o_b\n",
            "    gpt2/h31/norm_2/g\n",
            "    gpt2/h31/norm_2/b\n",
            "    gpt2/h31/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/ln_f/g\n",
            "    gpt2/ln_f/b\n",
            "Variable stacked/gpt2/h8/norm_2/g                                     size 130560       slice_size 65280        Shape[stacked=51, embd=2560]                                \n",
            "    gpt2/h8/norm_2/g\n",
            "    gpt2/h8/norm_2/b\n",
            "    gpt2/h8/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h9/norm_1/g\n",
            "    gpt2/h9/norm_1/b\n",
            "    gpt2/h9/attn/compute_output_bias/o_b\n",
            "    gpt2/h9/norm_2/g\n",
            "    gpt2/h9/norm_2/b\n",
            "    gpt2/h9/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h10/norm_1/g\n",
            "    gpt2/h10/norm_1/b\n",
            "    gpt2/h10/attn/compute_output_bias/o_b\n",
            "    gpt2/h10/norm_2/g\n",
            "    gpt2/h10/norm_2/b\n",
            "    gpt2/h10/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h11/norm_1/g\n",
            "    gpt2/h11/norm_1/b\n",
            "    gpt2/h11/attn/compute_output_bias/o_b\n",
            "    gpt2/h11/norm_2/g\n",
            "    gpt2/h11/norm_2/b\n",
            "    gpt2/h11/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h12/norm_1/g\n",
            "    gpt2/h12/norm_1/b\n",
            "    gpt2/h12/attn/compute_output_bias/o_b\n",
            "    gpt2/h12/norm_2/g\n",
            "    gpt2/h12/norm_2/b\n",
            "    gpt2/h12/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h13/norm_1/g\n",
            "    gpt2/h13/norm_1/b\n",
            "    gpt2/h13/attn/compute_output_bias/o_b\n",
            "    gpt2/h13/norm_2/g\n",
            "    gpt2/h13/norm_2/b\n",
            "    gpt2/h13/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h14/norm_1/g\n",
            "    gpt2/h14/norm_1/b\n",
            "    gpt2/h14/attn/compute_output_bias/o_b\n",
            "    gpt2/h14/norm_2/g\n",
            "    gpt2/h14/norm_2/b\n",
            "    gpt2/h14/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h15/norm_1/g\n",
            "    gpt2/h15/norm_1/b\n",
            "    gpt2/h15/attn/compute_output_bias/o_b\n",
            "    gpt2/h15/norm_2/g\n",
            "    gpt2/h15/norm_2/b\n",
            "    gpt2/h15/mlp/conv1d_main/c_proj/bias\n",
            "    gpt2/h16/norm_1/g\n",
            "    gpt2/h16/norm_1/b\n",
            "    gpt2/h16/attn/compute_output_bias/o_b\n",
            "    gpt2/h16/norm_2/g\n",
            "    gpt2/h16/norm_2/b\n",
            "    gpt2/h16/mlp/conv1d_main/c_proj/bias\n",
            "Trainable Variables            count: 200     Total size: 2651307520       Total slice_size: 381853440      \n",
            "All Variables                  count: 200     Total size: 2651307520       Total slice_size: 381853440      \n",
            "Counters:\n",
            "allreduce: 1.68e+10\n",
            " allreduce/[0]: 5.37e+09\n",
            "  allreduce/[0]/einsum_op: 5.37e+09\n",
            " allreduce/[1]: 1.14e+10\n",
            "  allreduce/[1]/einsum_op: 1.14e+10\n",
            "  allreduce/[1]/reduce_op: 1.9e+07\n",
            "einsum: 3.19e+13\n",
            "einsum_unique: 2.48e+13\n",
            "output: 2.02e+11\n",
            " output/AddOperation: 5.68e+10\n",
            " output/BinaryOpWithBroadcasting: 6.88e+08\n",
            " output/BroadcastOperation: 5.4e+09\n",
            " output/ConcatOperation: 2.69e+09\n",
            " output/Constant: 2.62e+05\n",
            " output/EinsumOperation: 5.59e+10\n",
            " output/ImportOperation: 1.31e+05\n",
            " output/OneHotOperation: 3.33e+09\n",
            " output/RangeOperation: 3.19e+05\n",
            " output/ReduceOperation: 2.95e+07\n",
            " output/ReshapeOperation: 1.01e+10\n",
            " output/ScalarAddOperation: 5.37e+09\n",
            " output/ScalarMultiplyOperation: 1.89e+10\n",
            " output/ShiftOperation: 1.34e+09\n",
            " output/SlicewiseOperation: 2.73e+10\n",
            " output/StackedVariable: 2.64e+06\n",
            " output/StopGradient: 8.05e+09\n",
            " output/UnstackOperation: 2.64e+06\n",
            " output/Variable: 3.05e+09\n",
            " output/WhileLoopOperation: 2.68e+09\n",
            "output_unique: 1.09e+11\n",
            " output_unique/AddOperation: 3.1e+10\n",
            " output_unique/BinaryOpWithBroadcasting: 8.81e+07\n",
            " output_unique/BroadcastOperation: 5.38e+09\n",
            " output_unique/ConcatOperation: 1.34e+09\n",
            " output_unique/Constant: 3.28e+04\n",
            " output_unique/EinsumOperation: 2.53e+10\n",
            " output_unique/ImportOperation: 1.64e+04\n",
            " output_unique/OneHotOperation: 4.16e+08\n",
            " output_unique/RangeOperation: 4.1e+04\n",
            " output_unique/ReduceOperation: 1.16e+07\n",
            " output_unique/ReshapeOperation: 5.37e+09\n",
            " output_unique/ScalarAddOperation: 2.68e+09\n",
            " output_unique/ScalarMultiplyOperation: 8.75e+09\n",
            " output_unique/ShiftOperation: 6.71e+08\n",
            " output_unique/SlicewiseOperation: 1.75e+10\n",
            " output_unique/StackedVariable: 8.24e+05\n",
            " output_unique/StopGradient: 6.71e+09\n",
            " output_unique/UnstackOperation: 8.24e+05\n",
            " output_unique/Variable: 2.65e+09\n",
            " output_unique/WhileLoopOperation: 1.34e+09\n",
            "variables: 2.65e+09\n",
            " variables/trainable: 2.65e+09\n",
            "Done calling model_fn.\n",
            "TPU job name worker\n",
            "Graph was finalized.\n",
            "Restoring parameters from gs://test-bucket-neo/GPT3_2-7B/model.ckpt-400000\n",
            "Running local_init_op.\n",
            "Done running local_init_op.\n",
            "From /usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:840: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Prefer Variable.assign which has equivalent behavior in 2.X.\n",
            "Starting infeed thread controller.\n",
            "Starting outfeed thread controller.\n",
            "Initialized dataset iterators in 0 seconds\n",
            "Before copy master to slices.\n",
            "Done with copy master to slices.\n",
            "Enqueue next (1) batch(es) of data to infeed.\n",
            "Dequeue next (1) batch(es) of data from outfeed.\n",
            "Outfeed finished for iteration (0, 0)\n",
            "======================================== SAMPLE 0 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "\n",
            "    def forward(self, input):\n",
            "        \"\"\" return gpt from position embedding (embedding for position and context)\"\"\"\n",
            "        return GPT(input, self.pos_emb, self.tok_emb, self.drop, self.ln_f, self.head)\n",
            "\n",
            "    def get_type_log_probability(self, input, target, p_type):\n",
            "        \"\"\" get negative log-likelihood for the current probability (p_type)\n",
            "        \"\"\"\n",
            "        embedding = self.tok_emb(input)\n",
            "        return nn.log_softmax(embedding, dim=1) / sum(input.size(1) for input in input)\n",
            "\n",
            "\n",
            "def update_parameters_for_training(model, input_length, targets,\n",
            "                                   target_length, context_size, apply_onehot=False):\n",
            "    \"\"\" update parameters after re-training or training in 2-shot.\n",
            "\n",
            "            model.set_params(...)..returns(model_post_training)\n",
            "            model_post_training: the updated model\n",
            "    \"\"\"\n",
            "    if not model.sampler:\n",
            "        model.reset_params()\n",
            "    elif model.sampler.get_seed()!= 0 or limit_sampled_sequences(model.sampler.get_seed()):\n",
            "        if apply_onehot:\n",
            "            model.reset_params()\n",
            "\n",
            "    loss = nn.BCELoss()\n",
            "    model.loss = loss\n",
            "    model.disp = model.disp + (1.0 - model.disp) * model.log_prob(input_length, target_length)\n",
            "    model.mean_disp = model.disp\n",
            "    model.mean_pos = model.pos\n",
            "    score = model.log_prob(target_length, target_length)\n",
            "    if (input_length == target_length):\n",
            "        # single shot - ignore intro, ilux and outros\n",
            "        if apply_onehot:\n",
            "            target[0][0] = '%s %s' % (target_length, target_length)\n",
            "        else:\n",
            "            target[0][0] = '%s %d' % (target_length, target_length)\n",
            "    else:\n",
            "        # 2-shot - batch one of the input embedding, multi-shot - batch by sequence.\n",
            "        targets = torch.cat([tuple([chr[0] if chr[0] in target[0] else '?' for chr in target])\n",
            "                             for target in target_length], 2)\n",
            "        target_length = len(targets)\n",
            "\n",
            "    pos_emb = self.pos_emb(input)\n",
            "    tok_emb = self.tok_emb(input)\n",
            "    drop = self.drop(input)\n",
            "\n",
            "    head_drop = tok_emb.nonlinearity * drop\n",
            "\n",
            "    for ln_f in self.ln_f:\n",
            "        self.ln_f = nn.LayerNorm(self.n_embd)\n",
            "        self.ln_f.weight.data.zero_()\n",
            "        self.ln_f.bias.data.zero_()\n",
            "\n",
            "    for block in self.blocks:\n",
            "        self.head_drop.weight.data.zero_()\n",
            "        self.head_drop.bias.data.zero_()\n",
            "\n",
            "    for i in range(self.n_layer):\n",
            "        param_tuple = (i, block, head_drop, len(targets), config.init_lstm_c)\n",
            "        t_pos, t_targets, _ = torch.max(target, param_tuple[0], param_tuple[1])\n",
            "\n",
            "        # fast threshold -> 1 will be equal to target, non-zero will not be all 0\n",
            "        t_pos = t_pos if t_pos == 0 else 1\n",
            "        t_targets = t_targets if t_targets == 0 else 1\n",
            "        self.pixel_to_pos = target[t_pos:t_pos+1]\n",
            "\n",
            "        # linear decrease\n",
            "        self.disp_drop = tok_emb.nonlinearity * self.drop(t_pos)\n",
            "        self.disp_drop.weight.data.zero_()\n",
            "\n",
            "        self.weight_reset = torch.zeros(2)\n",
            "        self.bias_reset = torch.zeros(2)\n",
            "\n",
            "        emb_tok_id = model.pixel_to_pos\n",
            "        weight_last = Embedding(1, config.n_embd)\n",
            "        self.q = weight_last(emb_tok_id)\n",
            "        self.q_last = weight_last(self.q)\n",
            "\n",
            "        mask_name = '%s/%s/%s_%d' % (config.tok_id, config.pos_id, tok_emb.size(), pos_emb.size())\n",
            "        self.loss_mask = nn.LogSoftmax(dim=1)\n",
            "        self.loss_state = nn.Linear(config.n_embd+num_tok_c, config.n_embd)\n",
            "        self.target_to_pos = per_target_pos(target, param_tuple[0], param_tuple[1], self.head_drop, label=targets)\n",
            "        self.loss_target_to_pos = per_target_pos(target, param_tuple[0], param_tuple[1], self.head_drop, label=targets)\n",
            "        self.mask_loss_name = \"loss_mask\"\n",
            "        target_to_pos = nn.LogSoftmax(dim=1)\n",
            "        for i in\n",
            "\n",
            "================================================================================\n",
            "\n",
            "======================================== SAMPLE 1 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "        # normalization\n",
            "        self.weight_gpu = nn.Parameter(torch.Tensor(self.weight.size(1).num()))\n",
            "        self.bias_gpu = nn.Parameter(torch.zeros(1).type(torch.float32))\n",
            "\n",
            "    def _init_weights(self):\n",
            "        num_b = self.head.weight.size(1)\n",
            "        drop_b = self.head.bias.size(0)\n",
            "        self.weight = nn.Parameter(torch.Tensor(num_b, drop_b))\n",
            "        self.bias = nn.Parameter(torch.zeros(drop_b).type(torch.float32))\n",
            "\n",
            "    def forward(self, H, g, X): \n",
            "        \"\"\"  - token-level feed forward\n",
            "        - Embed Otherwise\n",
            "            (f) g is ignored for the embeddings, and this is only used to save the\n",
            "                gpt translation encoder memory.\n",
            "        \"\"\"\n",
            "        output = {}\n",
            "        if self.head.keep:\n",
            "            X_top = X_top.view(-1, self.emb_size, 1)\n",
            "\n",
            "            for j in range(self.head.nheads):\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(1, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(0, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "\n",
            "            for i in range(self.head.n_layer):\n",
            "                H = torch.cat([H, self.ln_f(H)[0]]).view(-1)\n",
            "                if self.drop > 0:\n",
            "                    g = torch.zeros_like(H.long()).float()\n",
            "                else:\n",
            "                    g = H.long()\n",
            "\n",
            "                g = g.transpose(1, 2).contiguous().view(-1, g.size(1))\n",
            "                if self.apply_del_emb:\n",
            "                    output[j] = g.transpose(0, 1)\n",
            "                else:\n",
            "                    output[j] = self.head(g)\n",
            "\n",
            "                H = H.transpose(0, 1)\n",
            "        else:\n",
            "            X_top = X_top.view(-1, self.emb_size, 1)\n",
            "            for j in range(self.head.nheads):\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(1, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "                dX = X_top[:, 0]\n",
            "                dX = dX.transpose(0, 1)[0]\n",
            "                dX /= X_top[:, 1].sum(0, keepdim=1)[0]\n",
            "                X_top = self.head(dX)\n",
            "\n",
            "            for i in range(self.head.n_layer):\n",
            "                g = torch.cat([self.ln_f(H)[0], g])[0]\n",
            "                if self.drop > 0:\n",
            "                    g = torch.zeros_like(g).float()\n",
            "                else:\n",
            "                    g = g.transpose(1, 2).contiguous().view(-1, g.size(1))\n",
            "                if self.apply_del_emb:\n",
            "                    output[j] = g.transpose(0, 1)\n",
            "                else:\n",
            "                    output[j] = self.head(g)\n",
            "\n",
            "        output = output[\"h\"].transpose(0, 1)\n",
            "        return output\n",
            "\n",
            "\n",
            "\n",
            "================================================================================\n",
            "\n",
            "======================================== SAMPLE 2 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "\n",
            "        self.optimizer = optim.Adam(\n",
            "            self.head,\n",
            "            parameters_ub=self.parameters(),\n",
            "            lam=config.initial_learning_rate\n",
            "        )\n",
            "\n",
            "    def forward(self, input_text):\n",
            "        \"\"\" the overall model.co: forward pass \"\"\"\n",
            "\n",
            "        limit = self.head.output_size(0)\n",
            "        head = self.head\n",
            "        attn = self.head.weight\n",
            "        # tagwith = self.head.weight\n",
            "\n",
            "        block = self.blocks[:,0][self.block_size:,:]\n",
            "        forward_attn = block(attn)\n",
            "        forward_text = forward_attn + input_text\n",
            "        forward_text = conv_block(forward_text)\n",
            "        forward_text = forward_linear(forward_text)\n",
            "        forward_text = forward_linear(forward_linear(forward_text))\n",
            "\n",
            "        lower_attn = (self.tok_emb(forward_text)).sum(1, keepdim=True)\n",
            "        # lower_attn = self.tok_emb(1)\n",
            "\n",
            "        #rnn_basic_block1 = forward_attn[:self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0),:].view(1, 1, self.block_size, -1)\n",
            "        #rnn_basic_block1 = rnn_tok(forward_text[:self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0), :].transpose(1, 0, 2) + forward_text[self.head.layers_[0].output_size(1), self.head.layers_[0].output_size(0),:])\n",
            "        #rnn_basic_block1_drop = nn.Dropout(config.drop_rate)\n",
            "        #print(rnn_basic_block1_drop.shape)\n",
            "        #print(rnn_basic_block1.weight.shape)\n",
            "        # post_drop = rnn_basic_block1_drop.view(1, 1, self.block_size, 1)\n",
            "        # rnn_part_block1 = forward_attn[self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(1, self.block_size, self.head.n_layer)\n",
            "        # post_drop = post_drop + rnn_part_block1.weight.view(self.block_size, self.head.n_layer, 1) + rnn_part_block1.bias.view(self.head.n_layer, 1, 1).expand_as(rnn_part_block1)\n",
            "\n",
            "        #rnn_part_block1 = (head(head(numpy.squeeze(forward_text), 1)))[:self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(self.block_size, self.head.layers_[0].n_layer, -1)\n",
            "        #rnn_part_block1 = rnn_basic_block1_drop + post_drop + rnn_part_block1.weight.view(self.block_size, self.head.n_layer, 1) + rnn_part_block1.bias.view(self.head.n_layer, 1, 1).expand_as(rnn_part_block1)\n",
            "\n",
            "        lower_rnn_text = (head(head(numpy.squeeze(forward_text), 1)))[self.head.layers_[0].output_size(0), self.head.layers_[0].output_size(1), :].view(self.head.n_layer, self.block_size, -1)\n",
            "        lower_rnn = rnn_tok(lower_rnn_text)\n",
            "\n",
            "        #attention_layers = self.attention_layer\n",
            "        #context_attention_layers = self.context_attention_layer\n",
            "        #attn_context_layers = self.attention_layer + self.context_attention_layer\n",
            "        #attn_context_layers = self.attention_layer\n",
            "        #propagation_layers = self.proper_layer + self.context_attention_layer\n",
            "\n",
            "        return lower_attn + lower_rnn_text + lower_rnn\n",
            "\n",
            "    def backward(self, grad_output, grad_input):\n",
            "        \"\"\" the model.co: backward pass \"\"\"\n",
            "\n",
            "        grad_weight = torch.matmul(grad_output[self.head.layers_[0].output_size(0)], grad_input.contiguous())\n",
            "        return grad_weight.view(batch_size, -1, self.head.n_layer), grad_weight.view(batch_size, -1, self.head.n_layer)\n",
            "\n",
            "    def clip_gradient(self, grad_input):\n",
            "        \"\"\" clip gradient \"\"\"\n",
            "        logger.warning(\"clip_gradient: clip(grad_input, 0.0 - 1.0)\")\n",
            "        return grad_input.clamp(0.0 - 1.0).detach().cpu().numpy()\n",
            "\n",
            "    def _get_cell(self, name):\n",
            "        if self.args.tied_base_model:\n",
            "            return self.head.layers_[name].n_op\n",
            "\n",
            "        return self.head.layers_[name]\n",
            "\n",
            "    def _get_head(self, head_name):\n",
            "        if self.args.tied_base_model:\n",
            "            return head_name\n",
            "\n",
            "        return self.head.n_layer\n",
            "\n",
            "\n",
            "    def forward_gpt_cell(self, head):\n",
            "        \"\"\" the forward pass of the gpt\n",
            "\n",
            "================================================================================\n",
            "\n",
            "======================================== SAMPLE 3 ========================================\n",
            "\n",
            "\n",
            "class GPT(nn.Module):\n",
            "    \"\"\"  the full GPT language model, with a context size of block_size \"\"\"\n",
            "\n",
            "    def __init__(self, config):\n",
            "        super().__init__()\n",
            "\n",
            "        # input embedding stem\n",
            "        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)\n",
            "        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))\n",
            "        self.drop = nn.Dropout(config.embd_pdrop)\n",
            "        # transformer\n",
            "        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])\n",
            "        # decoder head\n",
            "        self.ln_f = nn.LayerNorm(config.n_embd)\n",
            "        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
            "\n",
            "        self.block_size = config.block_size\n",
            "        self.apply(self._init_weights)\n",
            "\n",
            "        logger.info(\"number of parameters: %e\", sum(p.numel() for p in self.parameters()))\n",
            "        logger.info(\"images size: %e\", config.images_len)\n",
            "        logger.info(\"embedding size: %e\", config.embedding_size)     \n",
            "\n",
            "        self.vocab_size = config.vocab_size\n",
            "        self.hidden_size = config.hidden_size\n",
            "        self.n_layer = config.n_layer\n",
            "        self.block_size = config.block_size\n",
            "        self.cell_dim = config.cell_dim\n",
            "        self.n_embd = config.n_embd\n",
            "        self.n_embd = config.n_embd\n",
            "        self.embd_pdrop = config.embd_pdrop\n",
            "        self.n_batch = config.n_batch\n",
            "        self.n_dembd = config.n_dembd\n",
            "        self.101k_embd = config.101k_embd\n",
            "        self.shotting_dist = config.shotting_dist\n",
            "        self.dropout = config.embd_pdrop\n",
            "\n",
            "        # init variables\n",
            "        self._init_weights()\n",
            "\n",
            "    def _init_weights(self):\n",
            "        for layer in self.blocks:\n",
            "            for cell in layer:\n",
            "                param_init = cell.init_weights()\n",
            "                self.parameters()[layer][cell] = param_init.assign(param_init)\n",
            "\n",
            "    def forward(self, x, gpt_emb, gpt_state, gpt_emb_dim, gpt_state_dim):\n",
            "        \"\"\"  a forward pass for language model derivations\n",
            "\n",
            "            input, latent and context embeddings of convolutional layers as well as the entity embedding to obtain the\n",
            "            topic-embedding applied to the gpt entity embedding to generate the knowledge graph representation\n",
            "            gpt latent representations are then transformed into some vector representation\n",
            "\n",
            "            latent representation is then used as input to the decoder head, to produce the gpt entity representation\n",
            "\n",
            "            finally, the gpt entity representation is used as input to the decoder head, to produce the gpt latent representation on top of which\n",
            "            the knowledge graph representation is constructed\n",
            "        \"\"\"\n",
            "\n",
            "        n_ctx = len(x)\n",
            "        x_bn = x.nonzero()[0]/n_ctx\n",
            "        latent_bn = x_bn.nonzero()[0]/n_ctx\n",
            "        cv_emb = self.tok_emb(x_bn)\n",
            "        # consider the entity embedding to get the gpt latent representation\n",
            "        entity_mask = self.apply(gpt_emb_dim) if self.embd else 0\n",
            "        latent = self.apply(gpt_state_dim)\n",
            "        latent = latent * gpt_emb + entity_mask * gpt_state + self.drop\n",
            "\n",
            "        self.ln_f.weight.data.fill_(1.0)\n",
            "        self.ln_f.bias.data.zero_()\n",
            "        self.ln_f.weight.data[0].copy_(self.tok_emb)\n",
            "        self.ln_f.bias.data[0].copy_(self.pos_emb)\n",
            "        mlp = nn.Linear(config.hidden_size, config.n_embd)\n",
            "        mlp.bias.data[0].copy_(self.hidden_size)\n",
            "        self.ln_f.weight.data[0].copy_(mlp.weight.data)\n",
            "        # get the gpt latent representation on top of which the knowledge graph\n",
            "        ln_gpt_emb = self.apply(gpt_emb_dim)\n",
            "        # ln_gpt_emb = logits.sample(self.shotting_dist)\n",
            "        # ln_gpt_emb_shape = [1]\n",
            "        # gpt_ln_emb.data[0].copy_(ln_gpt_emb.data[0])\n",
            "        # gpt_ln_emb_shape = [0]\n",
            "        # gpt_gpt_emb = gpt_ln_emb.gather([0], gpt_ln_emb.shape)\n",
            "        # get the gpt latent representation to be used as the starting latent embedding of the decoder\n",
            "        ln_src_emb = gpt_ln_emb\n",
            "        ln_state = gpt_ln_emb.gather([0], ln_gpt_emb.shape)\n",
            "\n",
            "        # get the context and latent embedding representation of the entire input\n",
            "        # x_ext = x_bn[latex_str].squeeze()\n",
            "        self.apply(n)\n",
            "        # get the context representation used to decode the embedded gpt representation\n",
            "        x_ext = x_bn[latex_str].squeeze()\n",
            "        x_ext = x_ext.transpose(1, 0)\n",
            "        x_ext = F.relu(self.apply(x_ext))\n",
            "        x_ext_bn = x_ext.transpose(1, 0)\n",
            "        # x_ext_bn = x_ext_bn.transpose(1, 0)\n",
            "        # initialize the decoder hidden state\n",
            "        ln_src_emb, diff_emb = collections.defaultdict(list), []\n",
            "        for i, i_emb in enumerate(self.ln_f):\n",
            "            i_blk = int(self.block_size*(i+1))\n",
            "            mlp = nn.Linear(context_embedding_dim, n\n",
            "\n",
            "================================================================================\n",
            "\n",
            "Enqueue next (1) batch(es) of data to infeed.\n",
            "Dequeue next (1) batch(es) of data from outfeed.\n",
            "Outfeed finished for iteration (1, 0)\n",
            "Stop infeed thread controller\n",
            "Shutting down InfeedController thread.\n",
            "InfeedController received shutdown signal, stopping.\n",
            "Infeed thread finished, shutting down.\n",
            "infeed marked as finished\n",
            "Stop output thread controller\n",
            "Shutting down OutfeedController thread.\n",
            "OutfeedController received shutdown signal, stopping.\n",
            "Outfeed thread finished, shutting down.\n",
            "outfeed marked as finished\n",
            "Shutdown TPU system.\n",
            "prediction_loop marked as finished\n",
            "prediction_loop marked as finished\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nE9VImzHaI0z"
      },
      "source": [
        "# Evaluating the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XGGbkgaFfp6f"
      },
      "source": [
        "This section assumes you are using a pretrained model and relies on variables created in the `Pretrained model` section."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I45yUIpbaLUJ"
      },
      "source": [
        "## Wikitext"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zwBDB9U2keFV"
      },
      "source": [
        "Download the wikitext test set:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uuugiBmJaNxf"
      },
      "source": [
        "wikitext103_src = \"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip\"\n",
        "!wget $wikitext103_src\n",
        "!unzip wikitext-103-raw-v1.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J5wf3QWKkhZt"
      },
      "source": [
        "Tokenize the data and upload it to your bucket:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6mo8UUtDdctH"
      },
      "source": [
        "\n",
        "!mkdir -p wikitext\n",
        "!mv /content/GPTNeo/wikitext-103-raw/wiki.test.raw wikitext/wikitext_test.txt\n",
        "\n",
        "# Tokenize Data\n",
        "!python data/create_tfrecords.py --input_dir wikitext --name wikitext --files_per 1000 --output_dir wikitext_tokenized --write_dataset_config --processes 1 --wikitext-detokenize\n",
        "\n",
        "# copy the data to your bucket\n",
        "if not path_to_cloud_bucket.endswith('/'):\n",
        "  path_to_cloud_bucket += '/'\n",
        "copy_loc = path_to_cloud_bucket \n",
        "!gsutil -m cp -r wikitext_tokenized $copy_loc\n",
        "!gsutil ls $path_to_cloud_bucket"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GE84TUd1fAzf"
      },
      "source": [
        "Now make a dataset config that points to the tokenized wikitext data. Replace `gs://test-bucket-neo` in `eval_path` with the bucket you copied the data to:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z5UU7DQeeY0S"
      },
      "source": [
        "%%writefile configs/dataset_configs/wikitext.json\n",
        "\n",
        "{\n",
        "  \"path\": \"\",\n",
        "  \"eval_path\": \"gs://test-bucket-neo/wikitext_tokenized/*.tfrecords\",\n",
        "  \"n_vocab\": 50256,\n",
        "  \"tokenizer_is_pretrained\": true,\n",
        "  \"tokenizer_path\": \"gpt2\",\n",
        "  \"eos_id\": 50256,\n",
        "  \"padding_id\": 50257\n",
        "}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "egvdwIOqfFER"
      },
      "source": [
        "And update your model config to point to that dataset:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "AtdoIFMgfOe8"
      },
      "source": [
        "# @title Modify config for wikitext. \n",
        "  \n",
        "import json\n",
        "from pprint import pprint\n",
        "\n",
        "batch_size = 8 #@param {type:\"integer\"}\n",
        "assert pretrained_model is not None\n",
        "with open(f'configs/{pretrained_model}.json', 'r') as f:\n",
        "  data = json.load(f)\n",
        "  pprint(data)\n",
        "  dset_val = [[\"wikitext\", None, None, None]]\n",
        "  mods = {\n",
        "          \"datasets\": dset_val,\n",
        "          \"eval_steps\": 139 // batch_size,\n",
        "          \"train_batch_size\": batch_size,\n",
        "          \"eval_batch_size\": batch_size,\n",
        "        }\n",
        "  data.update(mods)\n",
        "  print('\\n--->\\n')\n",
        "  pprint(data)\n",
        "  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\n",
        "    json.dump(data, outfile, indent=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2d5eTHEg6Xj"
      },
      "source": [
        "Now run the model in eval mode over the tokenized data:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s1Uz3PXzg5Pm"
      },
      "source": [
        "!python3 main.py --eval --tpu colab --model $pretrained_model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9dbkPVcMhVaR"
      },
      "source": [
        "## Lambada\n",
        "\n",
        "Lambada eval is built into the codebase and can be run by adding an `eval_tasks` field to your model config:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "z4FJXOlJiEYo"
      },
      "source": [
        "# @title Modify config for Lambada. \n",
        "  \n",
        "import json\n",
        "from pprint import pprint\n",
        "\n",
        "batch_size = 8 #@param {type:\"integer\"}\n",
        "assert pretrained_model is not None\n",
        "with open(f'configs/{pretrained_model}.json', 'r') as f:\n",
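        "  data = json.load(f)\n",
        "  # Same dataset entry as in the Wikitext section, redefined so this cell runs on its own.\n",
        "  dset_val = [[\"wikitext\", None, None, None]]\n",
        "  mods = {\n",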
        "          \"datasets\": dset_val,\n",
        "          \"eval_steps\": 0,\n",
        "          \"train_batch_size\": batch_size,\n",
        "          \"eval_batch_size\": batch_size,\n",
        "          \"eval_tasks\": [\"lambada\"]\n",
        "        }\n",
        "  data.update(mods)\n",
        "  print('\\n--->\\n')\n",
        "  pprint(data)\n",
        "  with open(f'configs/{pretrained_model}.json', 'w') as outfile:\n",
        "    json.dump(data, outfile, indent=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Upp-bGMriVPK"
      },
      "source": [
        "Now run the eval:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OOA1YZDRiUhN"
      },
      "source": [
        "!python3 main.py --eval --tpu colab --model $pretrained_model"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
